{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.1-final"
    },
    "colab": {
      "name": "VAE.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2wrkQRklOAeB"
      },
      "source": [
        "# Variational Auto Encoders using Ignite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Kvs_LwjOAeL"
      },
      "source": [
        "This is a tutorial on using Ignite to train neural network models, setup experiments and validate models.\n",
        "\n",
        "In this experiment, we'll be replicating [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114) by Kingma and Welling. This paper uses an encoder-decoder architecture to encode images to a vector and then reconstruct the images.\n",
        "\n",
        "We want to be able to encode and reconstruct MNIST images. MNIST is the classic machine learning dataset, it contains black and white images of digits 0 to 9. There are 50000 training images and 10000 test images. The dataset comprises of image and label pairs. \n",
        "\n",
        "We'll be using PyTorch to create the model, torchvision to import data and Ignite to train and monitor the models!\n",
        "\n",
        "Please note that a lot of this code has been borrowed from [official PyTorch example](https://github.com/pytorch/examples/tree/master/vae). Similar to that it uses ReLUs and the adam optimizer, instead of sigmoids and adagrad.\n",
        "\n",
        "Let's get started!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QFH5UZFhOAeM"
      },
      "source": [
        "## Required Dependencies\n",
        "\n",
        "In this example we only need `torchvision` package, assuming that `torch` and `ignite` are already installed. We can install it using `pip`:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HVhQpUjYOAeO"
      },
      "source": [
        "```\n",
        "pip install torchvision\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1s2up1PkOAeP"
      },
      "source": [
        "!pip install pytorch-ignite torchvision"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gnx9qzjIOAeQ"
      },
      "source": [
        "## Import Libraries"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9hKV-yeYOAeS"
      },
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "%matplotlib inline"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IMn3CAHqOAeS"
      },
      "source": [
        "We import `torch`, `nn` and `functional` modules to create our models! `DataLoader` to create iterators for the downloaded datasets.\n",
        "\n",
        "The code below also checks whether there are GPUs available on the machine and assigns the device to GPU if there are."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_dLPBDj-OAeT"
      },
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader\n",
        "from torch import nn, optim\n",
        "from torch.nn import functional as F\n",
        "SEED = 1234\n",
        "\n",
        "torch.manual_seed(SEED)\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DaX-nwZAOAeU"
      },
      "source": [
        "`torchvision` is a library that provides multiple datasets for computer vision tasks. Below we import the following:\n",
        "\n",
        "* **MNIST**: A module to download the MNIST datasets.\n",
        "* **save_image**: Saves tensors as images.\n",
        "* **make_grid**: Takes a concatenation of tensors and makes a grid of images.\n",
        "* **ToTensor**: Converts images to Tensors.\n",
        "* **Compose**: Collects transformations. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0oKaBceaOAeU"
      },
      "source": [
        "from torchvision.datasets import MNIST\n",
        "from torchvision.utils import save_image, make_grid\n",
        "from torchvision.transforms import Compose, ToTensor"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8q_cNs_xOAeW"
      },
      "source": [
        "`Ignite` is a High-level library to help with training neural networks in PyTorch. It comes with an `Engine` to setup a training loop, various metrics, handlers and a helpful contrib section! \n",
        "\n",
        "Below we import the following:\n",
        "* **Engine**: Runs a given process_function over each batch of a dataset, emitting events as it goes.\n",
        "* **Events**: Allows users to attach functions to an `Engine` to fire functions at a specific event. Eg: `EPOCH_COMPLETED`, `ITERATION_STARTED`, etc.\n",
        "* **MeanSquaredError**: Metric to calculate mean squared error. \n",
        "* **Loss**: General metric that takes a loss function as a parameter, calculate loss over a dataset.\n",
        "* **RunningAverage**: General metric to attach to Engine during training. \n",
        "* **ModelCheckpoint**: Handler to checkpoint models."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EzibOregOAeW"
      },
      "source": [
        "from ignite.engine import Engine, Events\n",
        "from ignite.metrics import MeanSquaredError, Loss, RunningAverage"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mqtxETPkOAeW"
      },
      "source": [
        "## Processing Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IzrrCIkaOAeX"
      },
      "source": [
        "Below the only transformation we use is to convert convert the images to Tensor, MNIST downloads the dataset on to your machine.\n",
        "\n",
        "* `train_data` is a list of tuples of image tensors and labels. `val_data` is the same, just a different number of images. \n",
        "* `image` is a 28 x 28 tensor with 1 channel, meaning a 28 x 28 grayscale image.\n",
        "* `label` is a single integer value, denoting what the image is showing."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Np63e2xCOAeX"
      },
      "source": [
        "data_transform = Compose([ToTensor()])\n",
        "\n",
        "train_data = MNIST(download=True, root=\"/tmp/mnist/\", transform=data_transform, train=True)\n",
        "val_data = MNIST(download=True, root=\"/tmp/mnist/\", transform=data_transform, train=False)\n",
        "\n",
        "image = train_data[0][0]\n",
        "label = train_data[0][1]\n",
        "\n",
        "print ('len(train_data) : ', len(train_data))\n",
        "print ('len(val_data) : ', len(val_data))\n",
        "print ('image.shape : ', image.shape)\n",
        "print ('label : ', label)\n",
        "\n",
        "img = plt.imshow(image.squeeze().numpy(), cmap='gray')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "euf13Tx6OAeX"
      },
      "source": [
        "Now let's setup iterators of the training and validation datasets. We can take advantage of PyTorch's `DataLoader` that allows us to specify the dataset, batch size, number of workers, device, and other helpful parameters. \n",
        "\n",
        "Let's see what the output of the iterators are:\n",
        "* We see that each batch consists of 32 images and their corresponding labels.\n",
        "* Examples are shuffled.\n",
        "* Data is placed on GPU if available, otherwise it uses CPU."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0_HaWmxTOAeY"
      },
      "source": [
        "kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {}\n",
        "\n",
        "train_loader = DataLoader(train_data, batch_size=32, shuffle=True, **kwargs)\n",
        "val_loader = DataLoader(val_data, batch_size=32, shuffle=True, **kwargs)\n",
        "\n",
        "for batch in train_loader:\n",
        "    x, y = batch\n",
        "    break\n",
        "\n",
        "print ('x.shape : ', x.shape)\n",
        "print ('y.shape : ', y.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GFBK6G7HOAeY"
      },
      "source": [
        "To visualize how well our model reconstruct images, let's save the above value of x as a set of images we can use to compare against the generation reconstructions from our model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_hhtM6UhOAeY"
      },
      "source": [
        "fixed_images = x.to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m2yFi8vhOAeZ"
      },
      "source": [
        "## VAE Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hauIEN0XOAeZ"
      },
      "source": [
        "VAE is a model comprised of fully connected layers that take a flattened image, pass them through fully connected layers reducing the image to a low dimensional vector. The vector is then passed through a mirrored set of fully connected weights from the encoding steps, to generate a vector of the same size as the input."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oQ5uwtzaOAea"
      },
      "source": [
        "class VAE(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(VAE, self).__init__()\n",
        "        self.fc1 = nn.Linear(784, 400)\n",
        "        self.fc21 = nn.Linear(400, 20)\n",
        "        self.fc22 = nn.Linear(400, 20)\n",
        "        self.fc3 = nn.Linear(20, 400)\n",
        "        self.fc4 = nn.Linear(400, 784)\n",
        "\n",
        "    def encode(self, x):\n",
        "        h1 = F.relu(self.fc1(x))\n",
        "        return self.fc21(h1), self.fc22(h1)\n",
        "\n",
        "    def reparameterize(self, mu, logvar):\n",
        "        std = torch.exp(0.5*logvar)\n",
        "        eps = torch.randn_like(std)\n",
        "        return eps.mul(std).add_(mu)\n",
        "\n",
        "    def decode(self, z):\n",
        "        h3 = F.relu(self.fc3(z))\n",
        "        return torch.sigmoid(self.fc4(h3))\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu, logvar = self.encode(x)\n",
        "        z = self.reparameterize(mu, logvar)\n",
        "        return self.decode(z), mu, logvar"
      ],
      "execution_count": null,
      "outputs": []
    },
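    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a reading aid for `reparameterize`, the sampling step can be written as below. This is a sketch of what the code above computes; the symbols $\\sigma$ and $\\epsilon$ are notation introduced here for `std` and `eps`, and $\\odot$ denotes element-wise multiplication.\n",
        "\n",
        "$$\n",
        "\\sigma = \\exp\\left(\\tfrac{1}{2} \\log \\sigma^2\\right), \\qquad \\epsilon \\sim \\mathcal{N}(0, I), \\qquad z = \\mu + \\sigma \\odot \\epsilon\n",
        "$$\n",
        "\n",
        "Because the randomness is isolated in $\\epsilon$, gradients can flow through $\\mu$ and $\\log \\sigma^2$ (`mu` and `logvar`) during the backward pass."
      ]
    },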
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mV27j9EDOAea"
      },
      "source": [
        "## Creating Model, Optimizer and Loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kJ6yD5fBOAeb"
      },
      "source": [
        "Below we create an instance of the VAE model. The model is placed on a device and then loss functions of Binary Cross Entropy + KL Divergence is used and Adam optimizer are setup. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0rAgYpT3OAeb"
      },
      "source": [
        "model = VAE().to(device)\n",
        "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
        "\n",
        "def kld_loss(x_pred, x, mu, logvar):\n",
        "    # see Appendix B from VAE paper:\n",
        "    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014\n",
        "    # https://arxiv.org/abs/1312.6114\n",
        "    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
        "    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
        "\n",
        "bce_loss = nn.BCELoss(reduction='sum')"
      ],
      "execution_count": null,
      "outputs": []
    },
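    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two terms concrete, here is a sketch of the per-image objective obtained when `bce_loss` and `kld_loss` are added together, as the training step does. The notation is introduced here for illustration: $x_i$ and $\\hat{x}_i$ are the pixels of the input and of the reconstruction, and $\\mu_j$, $\\sigma_j^2$ are the latent mean and variance (`mu` and `logvar.exp()`).\n",
        "\n",
        "$$\n",
        "\\mathcal{L}(x) = \\underbrace{-\\sum_{i=1}^{784} \\left[ x_i \\log \\hat{x}_i + (1 - x_i) \\log (1 - \\hat{x}_i) \\right]}_{\\text{BCE}} + \\underbrace{\\left( -\\frac{1}{2} \\sum_{j=1}^{20} \\left( 1 + \\log \\sigma_j^2 - \\mu_j^2 - \\sigma_j^2 \\right) \\right)}_{\\text{KLD}}\n",
        "$$\n",
        "\n",
        "Since `reduction='sum'` and `torch.sum` are used, both terms are also summed over all images in a batch rather than averaged."
      ]
    },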
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1ye6HrDeOAeb"
      },
      "source": [
        "## Training and Evaluating using Ignite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v-QR-oALOAeb"
      },
      "source": [
        "### Trainer Engine - process_function\n",
        "\n",
        "Ignite's `Engine` allows user to define a `process_function` to process a given batch, this is applied to all the batches of the dataset. This is a general class that can be applied to train and validate models! A `process_function` has two parameters engine and batch. \n",
        "\n",
        "\n",
        "Let's walk through what the function of the trainer does:\n",
        "\n",
        "* Sets model in train mode. \n",
        "* Sets the gradients of the optimizer to zero.\n",
        "* Generate `x` from batch.\n",
        "* Flattens `x` into shape `(-1, 784)`.\n",
        "* Performs a forward pass to reconstuct `x` as `x_pred` using model and x. Model also return `mu`, `logvar`.\n",
        "* Calculates loss using `x_pred`, `x`, `logvar` and `mu`.\n",
        "* Performs a backward pass using loss to calculate gradients for the model parameters.\n",
        "* model parameters are optimized using gradients and optimizer.\n",
        "* Returns scalar loss. \n",
        "\n",
        "Below is a single operation during the trainig process. This process_function will be attached to the training engine."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VwuW-CbeOAec"
      },
      "source": [
        "def process_function(engine, batch):\n",
        "    model.train()\n",
        "    optimizer.zero_grad()\n",
        "    x, _ = batch\n",
        "    x = x.to(device)\n",
        "    x = x.view(-1, 784)\n",
        "    x_pred, mu, logvar = model(x)\n",
        "    BCE = bce_loss(x_pred, x)\n",
        "    KLD = kld_loss(x_pred, x, mu, logvar)\n",
        "    loss = BCE + KLD\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    return loss.item(), BCE.item(), KLD.item()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p8KEf_jHOAec"
      },
      "source": [
        "### Evaluator Engine - process_function\n",
        "\n",
        "Similar to the training process function, we setup a function to evaluate a single batch. Here is what the `eval_function` does:\n",
        "\n",
        "* Sets model in eval mode.\n",
        "* Generates `x` from batch.\n",
        "* With `torch.no_grad()`, no gradients are calculated for any succeding steps.\n",
        "* Flattens `x` into shape `(-1, 784)`.\n",
        "* Performs a forward pass to reconstuct `x` as `x_pred` using model and x. Model also return  `mu`, `logvar`.\n",
        "* Returns `x_pred`, `x`, `mu` and `logvar`.\n",
        "\n",
        "Ignite suggests attaching metrics to evaluators and not trainers because during the training the model parameters are constantly changing and it is best to evaluate model on a stationary model. This information is important as there is a difference in the functions for training and evaluating. Training returns a single scalar loss. Evaluating returns `y_pred` and `y` as that output is used to calculate metrics per batch for the entire dataset.\n",
        "\n",
        "All metrics in `Ignite` require `y_pred` and `y` as outputs of the function attached to the `Engine`. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PceHJQJ4OAec"
      },
      "source": [
        "def evaluate_function(engine, batch):\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        x, _ = batch\n",
        "        x = x.to(device)\n",
        "        x = x.view(-1, 784)\n",
        "        x_pred, mu, logvar = model(x)\n",
        "        kwargs = {'mu': mu, 'logvar': logvar}\n",
        "        return x_pred, x, kwargs"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GdfA_leJOAec"
      },
      "source": [
        "### Instantiating Training and Evaluating Engines\n",
        "\n",
        "Below we create 2 engines, a `trainer` and `evaluator` using the functions defined above. We also define dictionaries to keep track of the history of the metrics on the training and validation sets. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EH494wHUOAed"
      },
      "source": [
        "trainer = Engine(process_function)\n",
        "evaluator = Engine(evaluate_function)\n",
        "training_history = {'bce': [], 'kld': [], 'mse': []}\n",
        "validation_history = {'bce': [], 'kld': [], 'mse': []}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B_nxE8WdOAed"
      },
      "source": [
        "### Metrics - RunningAverage, MeanSquareError and Loss\n",
        "\n",
        "To start, we'll attach a metric of `RunningAverage` to track a running average of the scalar `loss`, `binary cross entropy` and `KL Divergence` output for each batch. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "22b2EySTOAee"
      },
      "source": [
        "RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss')\n",
        "RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'bce')\n",
        "RunningAverage(output_transform=lambda x: x[2]).attach(trainer, 'kld')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fxmDDcc7OAee"
      },
      "source": [
        "Now there are two metrics that we want to use for evaluation - `mean squared error`, `binary cross entropy` and `KL Divergence`. If you noticed earlier, out `eval_function` returns `x_pred`, `x` and a few other values, `MeanSquaredError` only expects two values per batch. \n",
        "\n",
        "For each batch, the `engine.state.output` will be `x_pred`, `x` and `kwargs`, this is why we use `output_transform` to only extract values from `engine.state.output` based on the the metric need.\n",
        "\n",
        "As for `Loss`, we pass our defined `loss_function` and simply attach it to the `evaluator` as `engine.state.output` outputs all the parameters needed for `loss_function`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TxlHHgCjOAee"
      },
      "source": [
        "MeanSquaredError(output_transform=lambda x: [x[0], x[1]]).attach(evaluator, 'mse')\n",
        "Loss(bce_loss, output_transform=lambda x: [x[0], x[1]]).attach(evaluator, 'bce')\n",
        "Loss(kld_loss).attach(evaluator, 'kld')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AS5MFiXLOAee"
      },
      "source": [
        "### Attaching Custom Functions to Engine at specific Events\n",
        "\n",
        "Below you'll see ways to define your own custom functions and attaching them to various `Events` of the training process.\n",
        "\n",
        "The first method involves using a decorator, the syntax is simple - `@` `trainer.on(Events.EPOCH_COMPLETED)`, means that the decorated function will be attached to the `trainer` and called at the end of each epoch. \n",
        "\n",
        "The second method involves using the `add_event_handler` method of `trainer` - `trainer.add_event_handler(Events.EPOCH_COMPLETED, custom_function)`. This achieves the same result as the above.\n",
        "\n",
        "\n",
        "The function below print the loss during training at the end of each epoch. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uwPNFSwtOAee"
      },
      "source": [
        "@trainer.on(Events.EPOCH_COMPLETED)\n",
        "def print_trainer_logs(engine):\n",
        "    avg_loss = engine.state.metrics['loss']\n",
        "    avg_bce = engine.state.metrics['bce']\n",
        "    avg_kld = engine.state.metrics['kld']\n",
        "    print(\"Trainer Results - Epoch {} - Avg loss: {:.2f} Avg bce: {:.2f} Avg kld: {:.2f}\"\n",
        "          .format(engine.state.epoch, avg_loss, avg_bce, avg_kld))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s0tpWvp0OAef"
      },
      "source": [
        "The function below prints the logs of the `evaluator` and updates the history of metrics for training and validation datasets, we see that it takes parameters `DataLoader` and `mode`. Using this way we are repurposing a function and attaching it twice to the `trainer`, once to evaluate of the training dataset and other on the validation dataset."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DuSfcPc3OAeg"
      },
      "source": [
        "def print_logs(engine, dataloader, mode, history_dict):\n",
        "    evaluator.run(dataloader, max_epochs=1)\n",
        "    metrics = evaluator.state.metrics\n",
        "    avg_mse = metrics['mse']\n",
        "    avg_bce = metrics['bce']\n",
        "    avg_kld = metrics['kld']\n",
        "    avg_loss =  avg_bce + avg_kld\n",
        "    print(\n",
        "        mode + \" Results - Epoch {} - Avg mse: {:.2f} Avg loss: {:.2f} Avg bce: {:.2f} Avg kld: {:.2f}\"\n",
        "        .format(engine.state.epoch, avg_mse, avg_loss, avg_bce, avg_kld))\n",
        "    for key in evaluator.state.metrics.keys():\n",
        "        history_dict[key].append(evaluator.state.metrics[key])\n",
        "\n",
        "trainer.add_event_handler(Events.EPOCH_COMPLETED, print_logs, train_loader, 'Training', training_history)\n",
        "trainer.add_event_handler(Events.EPOCH_COMPLETED, print_logs, val_loader, 'Validation', validation_history)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "13po7FGPOAeg"
      },
      "source": [
        "The function below uses the set of images (`fixed_images`) and the VAE model to generate reconstructed images, the images are then formed into a grid, saved to your local machine and displayed in the notebook below. We attach this function to the start of the training process and at the end of each epoch, this way we'll be able to visualize how much better the model gets at reconstructing images. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2l8hhfITOAeg"
      },
      "source": [
        "def compare_images(engine, save_img=False):\n",
        "    epoch = engine.state.epoch\n",
        "    reconstructed_images = model(fixed_images.view(-1, 784))[0].view(-1, 1, 28, 28)\n",
        "    comparison = torch.cat([fixed_images, reconstructed_images])\n",
        "    if save_img:\n",
        "        save_image(comparison.detach().cpu(), 'reconstructed_epoch_' + str(epoch) + '.png', nrow=8)\n",
        "    comparison_image = make_grid(comparison.detach().cpu(), nrow=8)\n",
        "    fig = plt.figure(figsize=(5, 5));\n",
        "    output = plt.imshow(comparison_image.permute(1, 2, 0));\n",
        "    plt.title('Epoch ' + str(epoch));\n",
        "    plt.show();\n",
        "\n",
        "trainer.add_event_handler(Events.STARTED, compare_images, save_img=False)\n",
        "trainer.add_event_handler(Events.EPOCH_COMPLETED(every=5), compare_images, save_img=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v1RsmwNROAeh"
      },
      "source": [
        "### Run Engine\n",
        "\n",
        "Next, we'll run the `trainer` for 20 epochs and monitor results."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OgaJSdpVOAeh"
      },
      "source": [
        "e = trainer.run(train_loader, max_epochs=20)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4dJ_jY4_OAei"
      },
      "source": [
        "### Plotting Results\n",
        "\n",
        "Below we see plot the metrics collected on the training and validation sets. We plot the history of `Binary Cross Entropy`, `Mean Squared Error` and `KL Divergence`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pu9ab4LUOAei"
      },
      "source": [
        "plt.plot(range(20), training_history['bce'], 'dodgerblue', label='training')\n",
        "plt.plot(range(20), validation_history['bce'], 'orange', label='validation')\n",
        "plt.xlim(0, 20);\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('BCE')\n",
        "plt.title('Binary Cross Entropy on Training/Validation Set')\n",
        "plt.legend();"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XIuOnrwgOAei"
      },
      "source": [
        "plt.plot(range(20), training_history['kld'], 'dodgerblue', label='training')\n",
        "plt.plot(range(20), validation_history['kld'], 'orange', label='validation')\n",
        "plt.xlim(0, 20);\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('KLD')\n",
        "plt.title('KL Divergence on Training/Validation Set')\n",
        "plt.legend();"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1gN0gWiKOAei"
      },
      "source": [
        "plt.plot(range(20), training_history['mse'], 'dodgerblue', label='training')\n",
        "plt.plot(range(20), validation_history['mse'], 'orange', label='validation')\n",
        "plt.xlim(0, 20);\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('MSE')\n",
        "plt.title('Mean Squared Error on Training/Validation Set')\n",
        "plt.legend();"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
